[Kernel][Comms] feat: add custom all-gather kernels#1524
Conversation
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces custom all-gather kernels to optimize performance, particularly for context parallelism. The implementation is a good start, but I've identified several critical issues that need to be addressed. These include function signature mismatches that will cause build failures, a critical bug in handling variable-sized inputs due to incorrect use of std::set, and potential integer overflows from std::accumulate. Please review the detailed comments for fixes.
| } | ||
|
|
||
| private: | ||
| std::set<int> mGroup; |
There was a problem hiding this comment.
Using std::set for mGroup introduces a critical bug and is inefficient. std::set sorts its elements, which will break the correspondence between the ranks and the sizes vector if the input group_ranks is not sorted. This leads to incorrect data gathering when sizes are variable. Additionally, iterating over the set using std::advance inside run_list is inefficient.
Please change mGroup to be a std::vector<int> to preserve the order of ranks and allow for efficient indexing. This requires changes in multiple places:
- Change
mGroup's type tostd::vector<int>here. - Update
AllgatherOpconstructor (lines 72-74) to acceptstd::vector<int>. - Update
init_custom_ag(lines 148-154) to construct astd::vector<int>fromgroup_ranksand pass it toAllgatherOp. - Update the loop in
run_list(lines 118-121) to usemGroup[root_idx]instead ofstd::advance.
std::vector<int> mGroup;
| int64_t open_mem_handle(torch::Tensor& mem_handle); | ||
| void free_shared_buffer(int64_t buffer); | ||
|
|
||
| fptr_t init_custom_ag(const std::vector<int64_t>& group_ranks); |
There was a problem hiding this comment.
The signature of init_custom_ag is inconsistent with its implementation in custom_all_gather.cu and its declaration in custom_all_gather.cuh. The implementation expects a second argument nccl_comm_ptr of type int64_t, which is missing here. This will cause a linker error.
| fptr_t init_custom_ag(const std::vector<int64_t>& group_ranks); | |
| fptr_t init_custom_ag(const std::vector<int64_t>& group_ranks, int64_t nccl_comm_ptr); |
| TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ag), custom_ag) { | ||
| // Custom all-gather kernels | ||
| custom_ag.def( | ||
| "init_custom_ag(int[] group_ranks) -> int"); |
There was a problem hiding this comment.
The TorchScript signature for init_custom_ag is inconsistent with its C++ implementation. It's missing the nccl_comm_ptr argument. This will cause a compilation or runtime error. The signature should accept an additional integer for the communicator pointer.
| "init_custom_ag(int[] group_ranks) -> int"); | |
| "init_custom_ag(int[] group_ranks, int nccl_comm_ptr) -> int"); |
| std::all_of(sizes.value().begin(), sizes.value().end(), | ||
| [&sizes](int64_t size) { return size == sizes.value()[0]; }); | ||
|
|
||
| int64_t sum_sizes = sizes.has_value() ? std::accumulate(sizes.value().begin(), sizes.value().end(), 0, std::plus<>{}) : 0; |
There was a problem hiding this comment.
The initial value for std::accumulate is 0, which is an int. Since the sizes vector contains int64_t values, the sum could overflow an int if it exceeds INT_MAX. The accumulator's type is determined by the type of this initial value. To prevent overflow, please use an int64_t initial value.
int64_t sum_sizes = sizes.has_value() ? std::accumulate(sizes.value().begin(), sizes.value().end(), int64_t{0}, std::plus<>{}) : 0;
| AT_CUDA_CHECK(ncclAllGather(input.data_ptr(), output.mutable_data_ptr(), input.numel(), (*getDtypeMap())[type], | ||
| mNcclComm, stream)); | ||
| } else { | ||
| size_t numel_base = std::accumulate(outputShape.cbegin() + 1, outputShape.cend(), 1, std::multiplies<>{}); |
There was a problem hiding this comment.
The initial value for std::accumulate is 1, which is an int. The product of tensor dimensions can easily overflow an int. The accumulator's type is determined by this initial value. Please use a size_t initial value to prevent potential overflow, as the result is stored in a size_t.
size_t numel_base = std::accumulate(outputShape.cbegin() + 1, outputShape.cend(), size_t{1}, std::multiplies<>{});
We don't really use all-gather all that much, but for context parallel, all-gather is used quite a lot. This adds a fair bit of overhead when doing Context Parallelism, sometimes halving the speed. Currently WIP.
CP will land in #1521